
import torch


def collate_fn(batch):
    """Merge a list of ``(image, label, index)`` samples into batch tensors.

    Args:
        batch: Sequence of 3-tuples ``(image, label, index)``. The images are
            passed to ``torch.stack``, so they must be tensors of identical
            shape. Labels and indices must be convertible with ``int()``.

    Returns:
        A tuple ``(images, labels, indices)``:
        - ``images`` holds the images stacked along a new leading dimension.
        - ``labels`` and ``indices`` are 1-D ``torch.int64`` tensors.
        All three keep the sample order of ``batch``.
    """
    images, raw_labels, raw_indices = zip(*batch)

    # Coerce each entry through int() first so that numpy scalars and
    # 0-dim tensors are accepted as well as plain Python ints.
    label_tensor = torch.tensor(list(map(int, raw_labels)), dtype=torch.int64)
    index_tensor = torch.tensor(list(map(int, raw_indices)), dtype=torch.int64)

    stacked_images = torch.stack(images, dim=0)
    return stacked_images, label_tensor, index_tensor
